{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "08f5ad56",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.output_parsers import StructuredOutputParser, ResponseSchema\n",
    "from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate\n",
    "from langchain.llms import OpenAI\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "import pandas as pd\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "ed951cdf",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "openai_api_key = '...'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "bf74bc91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Temp = 0 so that we get clean information without a lot of creativity\n",
    "chat_model = ChatOpenAI(temperature=0, openai_api_key=openai_api_key, max_tokens=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "5845bbed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How you would like your response structured. This is basically a fancy prompt template\n",
    "response_schemas = [\n",
    "    ResponseSchema(name=\"input_industry\", description=\"This is the input_industry from the user\"),\n",
    "    ResponseSchema(name=\"standardized_industry\", description=\"This is the industry you feel is most closely matched to the users input\"),\n",
    "    ResponseSchema(name=\"match_score\",  description=\"A score 0-100 of how close you think the match is between user input and your match\")\n",
    "]\n",
    "\n",
    "# How you would like to parse your output\n",
    "output_parser = StructuredOutputParser.from_response_schemas(response_schemas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "191d2184",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The output should be a markdown code snippet formatted in the following schema:\n",
      "\n",
      "```json\n",
      "{\n",
      "\t\"input_industry\": string  // This is the input_industry from the user\n",
      "\t\"standarized_industry\": string  // This is the industry you feel is most closely matched to the users input\n",
      "\t\"match_score\": string  // A score 0-100 of how close you think the match is between user input and your match\n",
      "}\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "# See the prompt template you created for formatting\n",
    "format_instructions = output_parser.get_format_instructions()\n",
    "print (output_parser.get_format_instructions())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "4f7e3eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"\n",
    "You will be given a series of industry names from a user.\n",
    "Find the best corresponding match on the list of standardized names.\n",
    "The closest match will be the one with the closest semantic meaning. Not just string similarity.\n",
    "\n",
    "{format_instructions}\n",
    "\n",
    "Wrap your final output with closed and open brackets (a list of json objects)\n",
    "\n",
    "input_industry INPUT:\n",
    "{user_industries}\n",
    "\n",
    "STANDARDIZED INDUSTRIES:\n",
    "{standardized_industries}\n",
    "\n",
    "YOUR RESPONSE:\n",
    "\"\"\"\n",
    "\n",
    "prompt = ChatPromptTemplate(\n",
    "    messages=[\n",
    "        HumanMessagePromptTemplate.from_template(template)  \n",
    "    ],\n",
    "    input_variables=[\"user_industries\", \"standardized_industries\"],\n",
    "    partial_variables={\"format_instructions\": format_instructions}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "7d074208",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Corporate Services, Recreation & Travel, Legal, Wellness & Fitness, Entertainment, Consumer Goods, Design, Arts, Manufacturing, Finance, Health Care, Construction, Nonprofit, Real Estate, Software & IT Services, Hardware & Networking, Agriculture, Education, Public Administration, Transportation & Logistics, Public Safety, Media & Communications, Energy & Mining, Retail'"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Get your standardized names. You can swap this out with whatever list you want!\n",
    "df = pd.read_csv('../data/LinkedInIndustries.csv')\n",
    "standardized_industries = \", \".join(df['Industry'].values)\n",
    "standardized_industries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "b162ff8c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 1 message(s)\n",
      "Type: <class 'langchain.schema.HumanMessage'>\n",
      "---------------------------\n",
      "\n",
      "You will be given a series of industry names from a user.\n",
      "Find the best corresponding match on the list of standardized names.\n",
      "The closest match will be the one with the closest semantic meaning. Not just string similarity.\n",
      "\n",
      "The output should be a markdown code snippet formatted in the following schema:\n",
      "\n",
      "```json\n",
      "{\n",
      "\t\"input_industry\": string  // This is the input_industry from the user\n",
      "\t\"standarized_industry\": string  // This is the industry you feel is most closely matched to the users input\n",
      "\t\"match_score\": string  // A score 0-100 of how close you think the match is between user input and your match\n",
      "}\n",
      "```\n",
      "\n",
      "Wrap your final output with closed and open brackets (a list of json objects)\n",
      "\n",
      "input_industry INPUT:\n",
      "air LineZ, airline, aviation, planes that fly, farming, bread, wifi networks, twitter media agency\n",
      "\n",
      "STANDARDIZED INDUSTRIES:\n",
      "Corporate Services, Recreation & Travel, Legal, Wellness & Fitness, Entertainment, Consumer Goods, Design, Arts, Manufacturing, Finance, Health Care, Construction, Nonprofit, Real Estate, Software & IT Services, Hardware & Networking, Agriculture, Education, Public Administration, Transportation & Logistics, Public Safety, Media & Communications, Energy & Mining, Retail\n",
      "\n",
      "YOUR RESPONSE:\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Your user input\n",
    "\n",
    "user_input = \"air LineZ, airline, aviation, planes that fly, farming, bread, wifi networks, twitter media agency\"\n",
    "\n",
    "_input = prompt.format_prompt(user_industries=user_input, standardized_industries=standardized_industries)\n",
    "\n",
    "\n",
    "print (f\"There are {len(_input.messages)} message(s)\")\n",
    "print (f\"Type: {type(_input.messages[0])}\")\n",
    "print (\"---------------------------\")\n",
    "print (_input.messages[0].content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "559c27cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = chat_model(_input.to_messages())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "6f375f20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'langchain.schema.AIMessage'>\n",
      "\n",
      "\n",
      "[\n",
      "\t{\n",
      "\t\t\"input_industry\": \"air LineZ\",\n",
      "\t\t\"standarized_industry\": \"Transportation & Logistics\",\n",
      "\t\t\"match_score\": \"80\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"airline\",\n",
      "\t\t\"standarized_industry\": \"Transportation & Logistics\",\n",
      "\t\t\"match_score\": \"90\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"aviation\",\n",
      "\t\t\"standarized_industry\": \"Transportation & Logistics\",\n",
      "\t\t\"match_score\": \"95\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"planes that fly\",\n",
      "\t\t\"standarized_industry\": \"Transportation & Logistics\",\n",
      "\t\t\"match_score\": \"85\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"farming\",\n",
      "\t\t\"standarized_industry\": \"Agriculture\",\n",
      "\t\t\"match_score\": \"90\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"bread\",\n",
      "\t\t\"standarized_industry\": \"Consumer Goods\",\n",
      "\t\t\"match_score\": \"80\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"wifi networks\",\n",
      "\t\t\"standarized_industry\": \"Hardware & Networking\",\n",
      "\t\t\"match_score\": \"95\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"twitter media agency\",\n",
      "\t\t\"standarized_industry\": \"Media & Communications\",\n",
      "\t\t\"match_score\": \"90\"\n",
      "\t}\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "print (type(output))\n",
    "print (output.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "f2571dd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"```json\" in output.content:\n",
    "    json_string = output.content.split(\"```json\")[1].strip()\n",
    "else:\n",
    "    json_string = output.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "f3332c82",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "[\n",
      "\t{\n",
      "\t\t\"input_industry\": \"air LineZ\",\n",
      "\t\t\"standarized_industry\": \"Transportation & Logistics\",\n",
      "\t\t\"match_score\": \"80\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"airline\",\n",
      "\t\t\"standarized_industry\": \"Transportation & Logistics\",\n",
      "\t\t\"match_score\": \"90\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"aviation\",\n",
      "\t\t\"standarized_industry\": \"Transportation & Logistics\",\n",
      "\t\t\"match_score\": \"95\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"planes that fly\",\n",
      "\t\t\"standarized_industry\": \"Transportation & Logistics\",\n",
      "\t\t\"match_score\": \"85\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"farming\",\n",
      "\t\t\"standarized_industry\": \"Agriculture\",\n",
      "\t\t\"match_score\": \"90\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"bread\",\n",
      "\t\t\"standarized_industry\": \"Consumer Goods\",\n",
      "\t\t\"match_score\": \"80\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"wifi networks\",\n",
      "\t\t\"standarized_industry\": \"Hardware & Networking\",\n",
      "\t\t\"match_score\": \"95\"\n",
      "\t},\n",
      "\t{\n",
      "\t\t\"input_industry\": \"twitter media agency\",\n",
      "\t\t\"standarized_industry\": \"Media & Communications\",\n",
      "\t\t\"match_score\": \"90\"\n",
      "\t}\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "print(output.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af5ef426",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "9e8a0819",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'input_industry': 'air LineZ',\n",
       "  'standarized_industry': 'Transportation & Logistics',\n",
       "  'match_score': '80'},\n",
       " {'input_industry': 'airline',\n",
       "  'standarized_industry': 'Transportation & Logistics',\n",
       "  'match_score': '90'},\n",
       " {'input_industry': 'aviation',\n",
       "  'standarized_industry': 'Transportation & Logistics',\n",
       "  'match_score': '95'},\n",
       " {'input_industry': 'planes that fly',\n",
       "  'standarized_industry': 'Transportation & Logistics',\n",
       "  'match_score': '85'},\n",
       " {'input_industry': 'farming',\n",
       "  'standarized_industry': 'Agriculture',\n",
       "  'match_score': '90'},\n",
       " {'input_industry': 'bread',\n",
       "  'standarized_industry': 'Consumer Goods',\n",
       "  'match_score': '80'},\n",
       " {'input_industry': 'wifi networks',\n",
       "  'standarized_industry': 'Hardware & Networking',\n",
       "  'match_score': '95'},\n",
       " {'input_industry': 'twitter media agency',\n",
       "  'standarized_industry': 'Media & Communications',\n",
       "  'match_score': '90'}]"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# output_parser.parse(output.content) Ideally this works but not in all cases\n",
    "structured_data = json.loads(output.content)\n",
    "structured_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "34fcf70c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>input_industry</th>\n",
       "      <th>standarized_industry</th>\n",
       "      <th>match_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>air LineZ</td>\n",
       "      <td>Transportation &amp; Logistics</td>\n",
       "      <td>80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>airline</td>\n",
       "      <td>Transportation &amp; Logistics</td>\n",
       "      <td>90</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>aviation</td>\n",
       "      <td>Transportation &amp; Logistics</td>\n",
       "      <td>95</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>planes that fly</td>\n",
       "      <td>Transportation &amp; Logistics</td>\n",
       "      <td>85</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>farming</td>\n",
       "      <td>Agriculture</td>\n",
       "      <td>90</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>bread</td>\n",
       "      <td>Consumer Goods</td>\n",
       "      <td>80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>wifi networks</td>\n",
       "      <td>Hardware &amp; Networking</td>\n",
       "      <td>95</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>twitter media agency</td>\n",
       "      <td>Media &amp; Communications</td>\n",
       "      <td>90</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         input_industry        standarized_industry match_score\n",
       "0             air LineZ  Transportation & Logistics          80\n",
       "1               airline  Transportation & Logistics          90\n",
       "2              aviation  Transportation & Logistics          95\n",
       "3       planes that fly  Transportation & Logistics          85\n",
       "4               farming                 Agriculture          90\n",
       "5                 bread              Consumer Goods          80\n",
       "6         wifi networks       Hardware & Networking          95\n",
       "7  twitter media agency      Media & Communications          90"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(structured_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b175a51a",
   "metadata": {},
   "source": [
    "#### To Do\n",
    "1. Look at new incoming industries from the user\n",
    "2. Match against your data base of values you've already mapped\n",
    "3. For existing ones, save an API call and get the result from the data base\n",
    "4. For new ones, batch them together for your LLM to return back to you"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
